# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from beartype import beartype
import statsmodels.api as sm
from statsmodels.formula import api as smf
from typing import Optional, Union
from dramkit.gentools import (isnull,
                              check_list_arg,
                              raise_warn,
                              raise_error)


@beartype
def lr_smf_df(df: pd.DataFrame,
              ycol: str,
              xcols: Optional[Union[list, tuple, set, str]] = None,
              intercept: bool = True,
              check_data: bool = False,
              **kwargs):
    '''
    Fit a linear regression with `statsmodels.formula`.

    Parameters
    ----------
    df : pd.DataFrame
        Dataset.
    ycol : str
        Name of the dependent-variable column.
    xcols : str, list, tuple, set, None
        Names of the independent-variable columns; if None, all columns
        of ``df`` except ``ycol`` are used.
    intercept : bool
        Whether the regression includes an intercept term.
    check_data : bool
        Whether to validate the data (numeric, finite, enough complete
        samples) before fitting.
    **kwargs :
        Passed through to ``raise_warn`` / ``raise_error``.

    Examples
    --------
    >>> df = pd.DataFrame({'x': [1, 2, 3, 4, 5],
    ...                    'y': [2, 3, 4, 5, 6],
    ...                    # 'y': [2, 4, 6, 8, 10]
    ...                   })
    >>> xcols = ['x']
    >>> ycol = 'y'
    >>> mdl1 = lr_smf_df(df, ycol, xcols, intercept=True)
    >>> mdl2 = lr_smf_df(df, ycol, xcols, intercept=False)
    >>> print(mdl1.params)
    >>> print(mdl2.params)
    '''
    
    xcols = check_list_arg(xcols, allow_none=True)
    if isnull(xcols):
        xcols = [x for x in df.columns if x != ycol]
    xcols = [x for x in xcols if x != ycol]
    
    if df.shape[0] < len(xcols):
        raise_warn('DataTooLessWarn',
                   'The number of samples is less than the number of independent variables!',
                   **kwargs)
    
    if check_data:
        # numeric check
        try:
            df[xcols+[ycol]].astype(float)
        except (ValueError, TypeError):
            raise_error('DataTypeError',
                        'The data contains non-numeric values!',
                        **kwargs)
        
        # infinite-value check
        if (df[xcols+[ycol]] == np.inf).sum().sum() > 0 or \
           (df[xcols+[ycol]] == -np.inf).sum().sum() > 0:
            raise_error('DataInfError',
                        'The data contains infinite values!',
                        **kwargs)
        
        # sample-size check on complete rows
        notna = (~df[xcols+[ycol]].isna()).sum(axis=1)
        notna = notna[notna == len(xcols+[ycol])]
        if len(notna) < len(xcols):
            raise_warn('DataTooLessWarn',
                       'The number of complete samples is less than the number of independent variables!',
                       **kwargs)
            
        # missing-value check
        if df[xcols+[ycol]].isna().sum().sum() > 0:
            raise_warn('DataNaNWarning',
                       'The data contains missing values.',
                       **kwargs)
    
    if intercept: # with or without an intercept term
        formula = '{} ~ {} + 1'.format(ycol, ' + '.join(xcols))
    else:
        formula = '{} ~ {} - 1'.format(ycol, ' + '.join(xcols))
    mdl = smf.ols(formula, df).fit() # ordinary least squares
    
    return mdl
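

# A minimal sanity-check sketch for ``lr_smf_df`` on synthetic data; the
# ``_demo_lr_smf_df`` helper and the generated dataset are illustrative
# assumptions, not part of the original module API.
def _demo_lr_smf_df():
    rng = np.random.default_rng(0)
    x = rng.normal(size=100)
    demo = pd.DataFrame({'x': x,
                         'y': 1.0 + 2.0 * x + rng.normal(scale=0.1, size=100)})
    mdl = lr_smf_df(demo, 'y', ['x'], intercept=True)
    # the fitted params should be close to Intercept ~ 1.0 and x ~ 2.0
    print(mdl.params)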


def lr_fit_sm(X, y):
    '''
    Linear regression fit with statsmodels; returns the fitted
    parameters, R-squared, adjusted R-squared, in-sample predictions,
    and the fitted model.
    '''

    X = sm.add_constant(X) # add an intercept term
    mdl = sm.OLS(y, X).fit()

    params = mdl.params
    r2, r2_adj = mdl.rsquared, mdl.rsquared_adj

    y_pre = mdl.predict(X)

    return params, r2, r2_adj, y_pre, mdl
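

# A short usage sketch for ``lr_fit_sm`` on synthetic data; the
# ``_demo_lr_fit_sm`` helper and its fabricated inputs are illustrative
# assumptions, not part of the original module API.
def _demo_lr_fit_sm():
    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 2))
    y = 0.5 + X @ np.array([1.0, -2.0]) + rng.normal(scale=0.1, size=50)
    params, r2, r2_adj, y_pre, mdl = lr_fit_sm(X, y)
    # params[0] is the intercept added by sm.add_constant
    print(params, r2, r2_adj)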


class LRClassifierBinary(object):
    '''
    | Binary logistic regression classifier.
    |
    | With input X and output y, the model is:
    |     y` = Sigmoid(X * W + b), where y` denotes the model prediction;
    |     folding the intercept b into W (by appending a constant column
    |     to X) shortens this to:
    |     y` = Sigmoid(X * W)
    |     where Sigmoid(x) = 1.0 / (1 + np.exp(-x))
    | The per-sample log-likelihood (from maximum-likelihood estimation;
    | zero for a confident correct prediction, unbounded below for a
    | confident error) is:
    |     y * ln(y`) + (1-y) * ln(1-y`)
    | Turned into a minimization problem, the total cost is:
    |     Cost = -Sum(y * ln(y`) + (1-y) * ln(1-y`))
    |
    | Gradient descent:
    | Applying the chain rule to the equations above gives the derivative
    | of Cost with respect to W:
    |     J = X.T * (y` - y)
    | J is the gradient used in the gradient-descent update.
    |
    | Newton's method:
    | With the Hessian H = X.T * diag(y`) * diag(1-y`) * X, the update is:
    |     W = W - inv(H) * J
    |
    | Reference:
    | https://www.cnblogs.com/loongofqiao/p/8642045.html
    '''

    def __init__(self, opt_method='gd', max_iter=1000, lr=0.01):
        '''
        Parameters
        ----------
        opt_method : str
            Optimization method: 'gd' (gradient descent, default) or
            'newton'/'nt' (Newton's method).
        max_iter : int
            Maximum number of iterations.
        lr : float
            Learning rate (used by gradient descent).
        '''

        self.w = 'uninitialized parameter (shape: NcolX*1)'
        self.b = 'uninitialized parameter (intercept)'

        self.opt_method = opt_method

        self.max_iter = max_iter
        self.lr = lr

    def forward(self, X, w, b):
        '''Forward pass (the model expression)'''
        return self.sigmoid(np.dot(X, w) + b)

    @staticmethod
    def sigmoid(x):
        '''Sigmoid activation function'''
        return 1.0 / (1 + np.exp(-x))

    @staticmethod
    def add_const(X):
        '''Append a constant column to X (X must be 2-D)'''
        const = np.ones((X.shape[0], 1))
        return np.concatenate((X, const), axis=1)

    def fit(self, X_train, y_train):
        '''
        Train the model.

        Parameters
        ----------
        X_train : pd.DataFrame, np.array
            Training inputs, one sample per row.
        y_train : pd.Series, pd.DataFrame, np.array
            Training targets, one sample per row.
        '''

        X_train, y_train = np.array(X_train), np.array(y_train)
        NcolX = X_train.shape[1] # number of features
        Xconst = self.add_const(X_train) # append constant column to X
        # reshape y to a 2-D column vector
        if len(y_train.shape) == 1 or y_train.shape[1] == 1:
            y_train = y_train.reshape(-1, 1)

        # initialize the weights w and the intercept b
        self.w = np.zeros((NcolX, 1))
        self.b = 1

        # gradient descent
        if self.opt_method == 'gd':
            # stack coefficients and intercept into one 2-D column vector
            W = np.array(list(self.w[:, 0]) + [self.b]).reshape(-1, 1)

            # update W by gradient descent
            for _ in range(self.max_iter):
                h = self.sigmoid(np.dot(Xconst, W)) # forward pass
                grad = np.dot(Xconst.T, h - y_train) # gradient
                W -= self.lr * grad

            self.w = W[:-1, :]
            self.b = W[-1][-1]

        # Newton's method
        elif self.opt_method.lower() in ['newton', 'nt']:
            # stack coefficients and intercept into one 2-D column vector
            W = np.array(list(self.w[:, 0]) + [self.b]).reshape(-1, 1)

            # update W by Newton's method
            for _ in range(self.max_iter):
                p = self.sigmoid(np.dot(Xconst, W)) # forward pass
                grad = np.dot(Xconst.T, p - y_train) # gradient
                # Hessian matrix: X.T * diag(p) * diag(1-p) * X
                H = np.dot(Xconst.T, np.diag(p.reshape(-1,)))
                H = np.dot(H, np.diag(1 - p.reshape(-1,)))
                H = np.dot(H, Xconst)
                W = W - np.dot(np.linalg.inv(H), grad) # Newton update

            self.w = W[:-1, :]
            self.b = W[-1][-1]

        else:
            raise ValueError('unknown opt_method: {}'.format(self.opt_method))

        return self

    def predict_proba(self, X):
        '''Predict probabilities'''
        y_pre_p = self.forward(X, self.w, self.b)
        return y_pre_p.reshape(-1,)

    def predict(self, X, p_cut=0.5):
        '''Predict class labels'''
        y_pre_p = self.predict_proba(X)
        y_pre = (y_pre_p >= p_cut).astype(int)
        return y_pre
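

# Two cross-check sketches for ``LRClassifierBinary``; the helper names,
# the synthetic Gaussian data, and the chosen hyperparameters are
# illustrative assumptions, not part of the original module API.
def _demo_lr_classifier_binary():
    # fit on two overlapping Gaussian blobs and compare against sm.Logit
    rng = np.random.default_rng(0)
    X = np.vstack([rng.normal(loc=-1.0, size=(100, 2)),
                   rng.normal(loc=1.0, size=(100, 2))])
    y = np.array([0] * 100 + [1] * 100)
    clf = LRClassifierBinary(opt_method='gd', max_iter=10000, lr=0.001)
    clf.fit(X, y)
    acc = (clf.predict(X) == y).mean()
    print('w:', clf.w.reshape(-1), 'b:', clf.b, 'train accuracy:', acc)
    # reference fit; Logit's 'const' term corresponds to clf.b, and the
    # coefficients should roughly match up to optimization tolerance
    ref = sm.Logit(y, sm.add_constant(X)).fit(disp=0)
    print('Logit params:', ref.params)


def _demo_gradient_check():
    # finite-difference check of the gradient formula J = X.T * (y` - y)
    # stated in the class docstring
    rng = np.random.default_rng(1)
    X = LRClassifierBinary.add_const(rng.normal(size=(20, 2)))
    y = rng.integers(0, 2, size=(20, 1)).astype(float)
    W = rng.normal(size=(3, 1))

    def cost(W):
        p = LRClassifierBinary.sigmoid(X @ W)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    analytic = X.T @ (LRClassifierBinary.sigmoid(X @ W) - y)
    eps = 1e-6
    numeric = np.zeros_like(W)
    for i in range(W.shape[0]):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, 0] += eps
        Wm[i, 0] -= eps
        numeric[i, 0] = (cost(Wp) - cost(Wm)) / (2 * eps)
    print(np.allclose(analytic, numeric, atol=1e-4))  # expect True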

#%%
if __name__ == '__main__':
    from dramkit import TimeRecoder

    # matplotlib setup (SimHei so CJK glyphs render in labels/titles)
    import matplotlib as mpl
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['font.serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    import matplotlib.pyplot as plt


    tr = TimeRecoder()

    #%%
    data = pd.read_excel('./_test/test_data1.xlsx')
    # data = pd.read_excel('./_test/test_data2.xlsx')
    X = data[['x1', 'x2']]
    y = data['y']

    opt_method = 'gd'
    # opt_method = 'newton'

    max_iter = 10000
    lr = 0.001

    mdl = LRClassifierBinary(opt_method=opt_method, max_iter=max_iter, lr=lr)

    X_train, y_train = X, y
    mdl = mdl.fit(X_train, y_train)
    print('{} parameter results:\nw: \n{}\nb: {}'.format(opt_method, mdl.w, mdl.b))

    def plot_result(data, w, b, title=None):
        plt.figure(figsize=(10, 7))
        data0 = data[data['y'] == 0]
        data1 = data[data['y'] == 1]
        plt.plot(data0['x1'], data0['x2'], 'ob', label='y=0')
        plt.plot(data1['x1'], data1['x2'], 'or', label='y=1')

        # decision boundary: w[0]*x1 + w[1]*x2 + b = 0  =>  x2 = (-b - w[0]*x1) / w[1]
        x = np.arange(data['x1'].min(), data['x1'].max(), 0.1)
        y = (-b - w[0]*x) / w[1]
        plt.plot(x, y, '-')

        plt.legend(loc=0)

        if title:
            plt.title(title)

        plt.show()

    plot_result(data, mdl.w, mdl.b, opt_method)
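
    # Optional cross-check against statsmodels' Logit (uncomment to run);
    # the fitted coefficients should roughly match mdl.w and mdl.b up to
    # optimization tolerance:
    # ref = sm.Logit(y, sm.add_constant(X)).fit(disp=0)
    # print(ref.params)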

    #%%
    tr.used()
